#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <Eigen/LU>
#include "TFile.h"
#include "TMath.h"
#include "TRandom3.h"
#include "TTree.h"
#include "SimpleTrack.h"


using namespace std;
using namespace Eigen;



#ifndef M_PI
#define M_PI           3.14159265358979323846
#endif


// Signum for doubles: 1. for positive input, -1. for negative input and 0. for anything else
// (zero, or a NaN, for which both comparisons are false).
static inline double sign(double x)
{
  if(x > 0.){return 1.;}
  if(x < 0.){return -1.;}
  return 0.;
}


// Signum for ints: 1 for positive input, -1 for negative input, 0 for zero.
static inline int sign(int x)
{
  if(x == 0){return 0;}
  return (x > 0) ? 1 : -1;
}


// Intersects the transverse track circle with a detector circle of radius rad_det centred on the origin.
//
// The track circle is given by its curvature k_trk, angle phi_trk and offset d_trk: k_trk times its
// centre is (d_trk*k_trk + 1)*(cos(phi_trk), sin(phi_trk)), and its radius is 1/k_trk.
//
//   hel             selects which of the two crossing points is returned
//   startx, starty  not used for the result; kept so that existing callers need no change
//   x, y            receive the selected crossing point; left untouched when false is returned
//
// Returns false when the two circles do not cross, true otherwise.
bool intersect_circles(bool hel, double startx, double starty, double rad_det, double k_trk, double phi_trk, double d_trk, double& x, double& y)
{
  // the start point only fed a nearest-point selection that is disabled; the choice is made by hel alone
  (void)startx;
  (void)starty;

  // k_trk times the centre of the track circle, and k_trk times the distance of that centre from the origin
  double kd_signed = d_trk*k_trk + 1.;
  double kcx = kd_signed*cos(phi_trk);
  double kcy = kd_signed*sin(phi_trk);
  double kd = fabs(kd_signed);
  // no crossing when the circles are too far apart, or when one lies entirely inside the other
  if(kd > (k_trk*rad_det + 1.)){return false;}
  if(kd < fabs(k_trk*rad_det - 1.)){return false;}
  double kd_inv = 1./kd;
  double R2 = rad_det*rad_det;
  // a is the distance from the origin, along the line of centres, to the chord joining both crossing points
  double a = 0.5*(k_trk*R2 + ( d_trk*d_trk*k_trk + 2.*d_trk ))*kd_inv;
  double tmp1 = a*kd_inv;
  double P2x = kcx*tmp1;
  double P2y = kcy*tmp1;

  // h is half the length of that chord.  For (nearly) tangent circles rounding can make R2 - a*a
  // marginally negative; treat that as a single touching point instead of propagating NaN.
  double h2 = R2 - a*a;
  double h = (h2 > 0.) ? sqrt(h2) : 0.;

  // unit vector perpendicular to the line of centres
  double ux = -kcy*kd_inv;
  double uy = kcx*kd_inv;

  if(!hel)
  {
    x = P2x + ux*h;
    y = P2y + uy*h;
  }
  else
  {
    x = P2x - ux*h;
    y = P2y - uy*h;
  }

  return true;
}


// Intersects a helix with a cylinder of radius rad_det around the z axis.
//
// The transverse crossing point (x,y) comes from intersect_circles(); z is then obtained by moving
// from startz along the helix by the arc length between (startx,starty) and (x,y).
//
//   hel                      forwarded to intersect_circles() to pick one of the two crossing points
//   startx, starty, startz   reference point from which the arc length and the z displacement are measured
//   k_trk, phi_trk, d_trk    transverse helix parameters, forwarded to intersect_circles()
//   dzdl_trk                 z displacement per unit path length; the formula below needs |dzdl_trk| < 1
//   x, y, z                  receive the crossing point; z is left untouched when false is returned
//
// Returns false when intersect_circles() finds no crossing, true otherwise.
bool intersect_helix_cylinder(bool hel, double startx, double starty, double startz, double rad_det, double k_trk, double phi_trk, double d_trk, double dzdl_trk, double& x, double& y, double& z)
{
  if(!intersect_circles(hel, startx, starty, rad_det, k_trk, phi_trk, d_trk, x, y))
  {
    return false;
  }

  double k = k_trk;
  // chord length in the transverse plane between the start point and the crossing point
  double D = sqrt((startx-x)*(startx-x) + (starty-y)*(starty-y));

  // transverse arc length s subtending that chord on a circle of curvature k
  double s=0.;
  if(0.5*k*D > 0.1)
  {
    // closed form; the argument is capped just below 1 so that asin stays well defined
    double v = 0.5*k*D;
    if(v >= 0.999999){v=0.999999;}
    s = 2.*asin(v)/k;
  }
  else
  {
    // small argument: series expansion of 2*asin(k*D/2)/k, which also covers k == 0
    double temp1 = k*D*0.5;temp1*=temp1;
    double temp2 = D*0.5;
    s += 2.*temp2;
    temp2*=temp1;
    s += temp2/3.;
    temp2*=temp1;
    s += (3./20.)*temp2;
    temp2*=temp1;
    s += (5./56.)*temp2;
  }
  // |dz| = s*|dzdl|/sqrt(1 - dzdl^2); the sign of dzdl_trk decides the direction
  double dz = sqrt(s*s*dzdl_trk*dzdl_trk/(1. - dzdl_trk*dzdl_trk));
  if(dzdl_trk>0.){z = startz + dz;}
  else{z = startz - dz;}

  // the original fell off the end of a bool function here, which is undefined behaviour
  return true;
}


// Scatters a helix about the point (x0,y0,z0) by the spherical angles (polar, azimuth):
// polar is the total scattering angle, azimuth is the angle about the tangent at (x0,y0,z0).
//
//   polar, azimuth                 scattering angles in radians
//   x0, y0, z0                     scattering point
//   k_in, phi_in, d_in, dzdl_in    parameters of the incoming helix; |dzdl_in| must be < 1
//   z0_in                          not read; kept so that existing callers need no change
//   k_out, phi_out, d_out,
//   z0_out, dzdl_out               receive the parameters of the scattered helix
void scatterTrack(double polar, double azimuth, double x0, double y0, double z0, double k_in, double phi_in, double d_in, double z0_in, double dzdl_in, double& k_out, double& phi_out, double& d_out, double& z0_out, double& dzdl_out)
{
  (void)z0_in;

  // change coordinates such that x0,y0,z0 -> 0,0,0
  double phi = atan2((1. + k_in*d_in)*sin(phi_in) - k_in*y0, (1. + k_in*d_in)*cos(phi_in) - k_in*x0);
  // transverse tangent direction, a quarter turn away from phi
  double px = cos(phi + M_PI/2.);
  double py = sin(phi + M_PI/2.);
  // pz is chosen such that pz/|p| equals dzdl_in
  double pz = dzdl_in*sqrt(px*px + py*py)/sqrt(1. - dzdl_in*dzdl_in);
  
  double p = sqrt(px*px + py*py + pz*pz);
  double p_inv = 1./p;
  double ux = px*p_inv;
  double uy = py*p_inv;
  double uz = pz*p_inv;
  // rotate the u vector arbitrarily so that it is theta away from the original
  Vector3d u(ux,uy,uz);
  // find a vector not parallel to u
  // rotate about z by pi/2.  we are assuming u does not point along z
  // (the (2,2) entry is 0, so b is the rotated transverse part of u with no z component)
  MatrixXd rot_z = MatrixXd::Zero(3,3);
  rot_z(0,0) = 0.;  rot_z(0,1) = -1.;  rot_z(0,2) = 0.;
  rot_z(1,0) = 1.;  rot_z(1,1) = 0.;   rot_z(1,2) = 0.;
  rot_z(2,0) = 0.;  rot_z(2,1) = 0.;   rot_z(2,2) = 0.;
  Vector3d b = rot_z*u;
  // take the cross product of u and b to find a vector perpendicular to u
  Vector3d perp = u.cross(b);
  double perp_mag = sqrt(perp.dot(perp));
  double perp_mag_inv = 1./perp_mag;
  perp *= perp_mag_inv;
  // rotate by the polar angle about our axis perp (axis-angle rotation matrix)
  MatrixXd rot_p = MatrixXd::Zero(3, 3);
  double cos_p = cos(polar);
  double sin_p = sin(polar);
  rot_p(0,0) = cos_p + perp(0)*perp(0)*(1.-cos_p);
  rot_p(0,1) = perp(0)*perp(1)*(1.-cos_p) - perp(2)*sin_p;
  rot_p(0,2) = perp(0)*perp(2)*(1.-cos_p) + perp(1)*sin_p;
  rot_p(1,0) = perp(1)*perp(0)*(1.-cos_p) + perp(2)*sin_p;
  rot_p(1,1) = cos_p + perp(1)*perp(1)*(1.-cos_p);
  rot_p(1,2) = perp(1)*perp(2)*(1.-cos_p) - perp(0)*sin_p;
  rot_p(2,0) = perp(2)*perp(0)*(1.-cos_p) - perp(1)*sin_p;
  rot_p(2,1) = perp(2)*perp(1)*(1.-cos_p) + perp(0)*sin_p;
  rot_p(2,2) = cos_p + perp(2)*perp(2)*(1.-cos_p);
  b = rot_p*u;
  double vx = b(0);
  double vy = b(1);
  double vz = b(2);
  
  // rotate by the randomly generated angle azimuth around u
  MatrixXd rot = MatrixXd::Zero(3, 3);
  double cos_az = cos(azimuth);
  double sin_az = sin(azimuth);
  rot(0,0) = cos_az + ux*ux*(1.-cos_az);
  rot(0,1) = ux*uy*(1.-cos_az) - uz*sin_az;
  rot(0,2) = ux*uz*(1.-cos_az) + uy*sin_az;
  rot(1,0) = uy*ux*(1.-cos_az) + uz*sin_az;
  rot(1,1) = cos_az + uy*uy*(1.-cos_az);
  rot(1,2) = uy*uz*(1.-cos_az) - ux*sin_az;
  rot(2,0) = uz*ux*(1.-cos_az) - uy*sin_az;
  rot(2,1) = uz*uy*(1.-cos_az) + ux*sin_az;
  rot(2,2) = cos_az + uz*uz*(1.-cos_az);
  MatrixXd vec = MatrixXd::Zero(3, 1);
  vec(0) = vx;
  vec(1) = vy;
  vec(2) = vz;
  MatrixXd rot_vec = MatrixXd::Zero(3, 1);
  
  rot_vec = rot*vec;
  // rot_vec is the tangent direction of the scattered helix at (0,0,0)
  // calculate parameters of new helix
  // The rotations leave the momentum magnitude unchanged and the curvature is proportional to 1/pT
  // (see pT_to_kappa), so k scales with the OLD transverse fraction divided by the NEW one.
  double k = k_in*sqrt( (ux*ux + uy*uy)/(rot_vec(0)*rot_vec(0) + rot_vec(1)*rot_vec(1)) );
  double dzdl = rot_vec(2)/sqrt(rot_vec(0)*rot_vec(0) + rot_vec(1)*rot_vec(1) + rot_vec(2)*rot_vec(2));
  // phi changes by the kick in the xy plane
  double cosval = (rot_vec(0)*ux + rot_vec(1)*uy)/(sqrt((ux*ux + uy*uy)*(rot_vec(0)*rot_vec(0) + rot_vec(1)*rot_vec(1))));
  // guard acos against rounding just outside [-1,1]
  if(cosval > 1.){cosval = 1.;}
  else if(cosval < -1.){cosval = -1.;}
  double dphi = acos(cosval);
  // the sign of the cross product gives the direction of the phi kick
  dphi *= sign(rot_vec(0)*uy - rot_vec(1)*ux);
  phi += dphi;
  
  //convert back to coordinates 0,0,0 -> x0,y0,z0
  double cosphi = cos(phi);
  double sinphi = sin(phi);
  double tx = cosphi + k*x0;
  double ty = sinphi + k*y0;
  double dk = sqrt( tx*tx + ty*ty ) - 1.;
  // straight-line limit: dk/k would be 0/0, so use the projection of the point onto the phi direction
  if(k == 0.){d_out = x0*cosphi + y0*sinphi;}
  else{d_out = dk/k;}
  phi_out = atan2(ty, tx);
  cosphi = cos(phi_out);
  sinphi = sin(phi_out);
  k_out = k;
  dzdl_out = dzdl;
  // now solve for the new z0
  // (dx,dy) is the point of closest approach of the new helix; D is the chord from there to (x0,y0)
  double dx = d_out*cosphi;
  double dy = d_out*sinphi;
  double D = sqrt( (x0-dx)*(x0-dx) + (y0-dy)*(y0-dy) );
  double s=0.;
  if(0.5*k_out*D > 0.1)
  {
    // rounding can push the argument marginally above 1, which would make asin return NaN
    double v = 0.5*k_out*D;
    if(v > 1.){v = 1.;}
    s = 2.*asin(v)/k_out;
  }
  else
  {
    // small argument: series expansion of 2*asin(k*D/2)/k, which also covers k == 0
    double temp1 = k_out*D*0.5;temp1*=temp1;
    double temp2 = D*0.5;
    s += 2.*temp2;
    temp2*=temp1;
    s += temp2/3.;
    temp2*=temp1;
    s += (3./20.)*temp2;
    temp2*=temp1;
    s += (5./56.)*temp2;
  }
  // step back from z0 by the z displacement accumulated over the arc length s
  double dz = sqrt(s*s*dzdl_out*dzdl_out/(1. - dzdl_out*dzdl_out));
  if(dzdl_out > 0.){z0_out = z0 - dz;}
  else{z0_out = z0 + dz;}
}


// Converts a transverse momentum into a track curvature.
// Units: pT in GeV/c, B in T, kappa in cm^-1.
static inline double pT_to_kappa(double B, double pT)
{
  const double scaled_field = 0.003*B;
  return scaled_field/pT;
}


// Converts a track curvature into a transverse momentum; this is the pT_to_kappa relation
// (kappa = 0.003*B/pT) solved for pT.
static inline double kappa_to_pT(double B, double kappa)
{
  const double scaled_field = 0.003*B;
  return scaled_field/kappa;
}


// Propagates one particle through the detector layers, scattering it at every layer it reaches.
//
//   B                magnetic field, forwarded to pT_to_kappa()
//   charge           particle charge; its sign selects the bending direction, |charge| scales the scattering width
//   mass             particle mass, used for beta = p/sqrt(mass^2 + p^2)
//   px, py, pz       momentum at the vertex
//   vx, vy, vz       vertex position
//   radii            layer radii, visited in the order given
//   rad_length       per-layer argument of the sqrt() in the scattering width, indexed like radii
//   rand             generator used to draw the scattering angles
//   intersections    intersections[l][0..2] receives (x,y,z) for every layer reached; it must already hold
//                    at least radii.size() entries with at least 3 elements each
//   k_out, d_out, phi_out, dzdl_out, z0_out
//                    receive the helix parameters at production, before any scattering
//
// Returns the number of layers reached: the index of the first layer missed, or radii.size().
int shootTrack(double B, int charge, double mass, double px, double py, double pz, double vx, double vy, double vz, vector<double>& radii, vector<double>& rad_length, TRandom3& rand, vector<vector<double> >& intersections, double& k_out, double& d_out, double& phi_out, double& dzdl_out, double& z0_out)
{
  // calculate helix parameters after translating vx,vy,vz -> 0,0,0
  double pt = sqrt(px*px + py*py);
  double p = sqrt(px*px + py*py + pz*pz);
  double k = pT_to_kappa(B, pt);
  double dzdl = pz/p;
  double p_inv = 1./p;
  double ux = px*p_inv;
  double uy = py*p_inv;
  // phi is a quarter turn away from the momentum direction; the side depends on the sign of the charge
  double phi_u = atan2(uy,ux);
  double phi = phi_u - sign((double)(charge))*0.5*M_PI;

  //convert back to coordinates 0,0,0 -> vx,vy,vz
  double tx = cos(phi) + k*vx;
  double ty = sin(phi) + k*vy;
  double dk = sqrt( tx*tx + ty*ty ) - 1.;
  double d = dk/k;
  phi = atan2(ty, tx);
  // The point of closest approach lies along the TRANSLATED phi, so cos/sin must be re-evaluated here
  // (scatterTrack does the same); reusing the pre-translation values put it in the wrong place for
  // displaced vertices.
  double cosphi = cos(phi);
  double sinphi = sin(phi);
  // now solve for the new z0
  double dx = d*cosphi;
  double dy = d*sinphi;
  double D = sqrt( (vx-dx)*(vx-dx) + (vy-dy)*(vy-dy) );
  // transverse arc length between the point of closest approach and the vertex
  double s=0.;
  if(0.5*k*D > 0.1)
  {
    // rounding can push the argument marginally above 1, which would make asin return NaN
    double v = 0.5*k*D;
    if(v > 1.){v = 1.;}
    s = 2.*asin(v)/k;
  }
  else
  {
    // small argument: series expansion of 2*asin(k*D/2)/k
    double temp1 = k*D*0.5;temp1*=temp1;
    double temp2 = D*0.5;
    s += 2.*temp2;
    temp2*=temp1;
    s += temp2/3.;
    temp2*=temp1;
    s += (3./20.)*temp2;
    temp2*=temp1;
    s += (5./56.)*temp2;
  }
  double dz = sqrt(s*s*dzdl*dzdl/(1. - dzdl*dzdl));
  double z0 = 0.;
  if(dzdl > 0.){z0 = vz - dz;}
  else{z0 = vz + dz;}

  // we now have k,d,phi,dzdl,z0

  k_out = k;
  d_out = d;
  phi_out = phi;
  dzdl_out = dzdl;
  z0_out = z0;

  // calculate the scattering widths: 0.0136/(beta*p) * |charge| * sqrt(rad_length), times sqrt(2).
  // beta was computed but never applied in the original, which overstated nothing for beta ~ 1 but
  // understated the width for slow, massive particles.
  // NOTE(review): the sqrt(2) presumably turns a plane-projected width into a space-angle width - confirm.
  double beta = p/sqrt(mass*mass + p*p);
  double chrg = fabs((double)(charge));
  double sqrt_2 = 1.41421356237309515e+00;
  vector<double> scatter_width;
  scatter_width.reserve(radii.size());
  for(unsigned int l=0;l<radii.size();++l)
  {
    double theta0 = sqrt_2*(0.0136*p_inv/beta)*chrg*sqrt(rad_length[l]);
    scatter_width.push_back(theta0);
  }

  // now shoot through the layers
  // positive charge selects one crossing point in intersect_circles(), zero or negative charge the other
  const bool hel = (charge > 0);
  double k_c = k;
  double d_c = d;
  double phi_c = phi;
  double dzdl_c = dzdl;
  double z0_c = z0;
  for(unsigned int l=0;l<radii.size();++l)
  {
    // NOTE(review): the arc length and z are always measured from the vertex, even after the helix has
    // been scattered, and z0_c is not used for the propagation - confirm this is intended.
    double x=0.;double y=0.;double z=0.;
    if(!intersect_helix_cylinder(hel, vx, vy, vz, radii[l], k_c, phi_c, d_c, dzdl_c, x, y, z))
    {
      return static_cast<int>(l);
    }

    intersections[l][0] = x;
    intersections[l][1] = y;
    intersections[l][2] = z;

    // draw the scattering angles (polar first, then azimuth) and replace the current helix by the scattered one
    double polar = fabs(rand.Gaus(0., scatter_width[l]));
    double azimuth = rand.Uniform(-M_PI, M_PI);
    // copies are needed because the current parameters are also the output arguments
    double k_in = k_c;
    double phi_in = phi_c;
    double d_in = d_c;
    double dzdl_in = dzdl_c;
    double z0_in = z0_c;
    scatterTrack(polar, azimuth, x, y, z, k_in, phi_in, d_in, z0_in, dzdl_in, k_c, phi_c, d_c, z0_c, dzdl_c);
  }
  return static_cast<int>(radii.size());
}


int main(int argc, char** argv)
{
  char* trackfile = argv[1];
  char* detfile = argv[2];
  char* outfile = argv[3];
  
  stringstream ss;
  double noise_proportion = 0.;
  ss.clear();ss.str("");ss<<argv[4];
  ss>>noise_proportion;
  
  
  TFile* ofile = new TFile(argv[3], "recreate");
  TTree* mc_tree = new TTree("events", "a tree of SimpleMCEvent.  The last track in each event just contains a list of noise hits");
  SimpleMCEvent* mcevent = new SimpleMCEvent();
  mc_tree->Branch("tracks", "SimpleMCEvent", &mcevent);
  
  
  
  
  
  double radius=0.;
  double radlen=0.;
  double xr=0.;
  double zr=0.;
  vector<double> radii;
  vector<double> rad_length;
  vector<double> x_res;
  vector<double> z_res;
  double sqrt_12_inv = 1./sqrt(12.);
  double B = 0.;
  ifstream det_in;
  det_in.open(detfile);
  det_in>>B>>radius>>radlen>>xr>>zr;
  while(det_in.good())
  {
    radii.push_back(radius);
    rad_length.push_back(radlen);
    x_res.push_back(xr*sqrt_12_inv);
    z_res.push_back(zr*sqrt_12_inv);
    det_in>>radius>>radlen>>xr>>zr;
  }
  
  int chrg=0;
  double mass=0.;
  double px=0.;
  double py=0.;
  double pz=0.;
  double vx=0.;
  double vy=0.;
  double vz=0.;
  
  ifstream in;
  in.open(trackfile);
  unsigned int ev = 0;
  unsigned int event = 0;
  in>>event>>chrg>>mass>>px>>py>>pz>>vx>>vy>>vz;
  vector<double> one_intersection;one_intersection.assign(3,0.);
  vector<vector<double> > intersections;intersections.assign(radii.size(), one_intersection);
  TRandom3 rand;
  unsigned int index = 0;
  while(in.good())
  {
    double trk_phi=0.;
    double trk_d=0.;
    double trk_kappa=0.;
    double trk_z0=0.;
    double trk_dzdl=0.;
    
    int nlayers = shootTrack(B, chrg, mass, px, py, pz, vx, vy, vz, radii, rad_length, rand, intersections, trk_kappa, trk_d, trk_phi, trk_dzdl, trk_z0);
    
    SimpleMCTrack onetrack;
    onetrack.phi = trk_phi;
    onetrack.d = trk_d;
    onetrack.kappa = trk_kappa;
    onetrack.z0 = trk_z0;
    onetrack.dzdl = trk_dzdl;
    
    bool spoiltrack = false;
    int rnd = ::rand() % 5 - 1;

    for(int l=0;l<nlayers;++l)
    {
	  //uncomment this line to remove hits in some layer
	  //spoiltrack = l==4;//rnd;
      if (spoiltrack) continue;
		
      SimpleMCHit onehit;
      onehit.layer = l;
      onehit.index = index;index+=1;
      double phi = atan2(intersections[l][1], intersections[l][0]);
      double sr = rand.Uniform(-x_res[l],x_res[l]);
//       double sr = rand.Gaus(0.0, x_res[l]);// TODO need an upper limit on this because it is always bounded by a few pixel lengths
      double x = intersections[l][0] + sr*cos(phi);
      double y = intersections[l][1] + sr*sin(phi);
      double sz = rand.Uniform(-z_res[l],z_res[l]);
//       double sz = rand.Gaus(0.0, z_res[l]);
      double z = intersections[l][2] + sz;
      onehit.x = x;
      onehit.y = y;
      onehit.z = z;
      
      onetrack.hits.push_back(onehit);
    }
    mcevent->tracks.push_back(onetrack);
    onetrack.hits.clear();
    
    in>>event>>chrg>>mass>>px>>py>>pz>>vx>>vy>>vz;
    if((event != ev) || (!(in.good())))
    {
      
      unsigned int noise_hits = (unsigned int)(noise_proportion*((double)(mcevent->tracks.size())));
      
      onetrack.phi = 0.;
      onetrack.d = 0.;
      onetrack.kappa = 0.;
      onetrack.z0 = 0.;
      onetrack.dzdl = 0.;
      
      for(unsigned int n=0;n<noise_hits;++n)
      {
        SimpleMCHit onehit;
        for(unsigned int l=0;l<radii.size();++l)
        {
          double phi = rand.Uniform(0., 2.*TMath::Pi());
          onehit.x = radii[l]*cos(phi);
          onehit.y = radii[l]*sin(phi);
          double dzdr = rand.Uniform(-0.5, 0.5);
          onehit.z = dzdr*radii[l];
          onehit.layer = l;
          onehit.index = index;index+=1;
          onetrack.hits.push_back(onehit);
        }
      }
      mcevent->tracks.push_back(onetrack);
      onetrack.hits.clear();
      
      ev = event;
      mc_tree->Fill();
      mcevent->tracks.clear();
      index=0;
    }
  }
  
  mc_tree->Write();
  ofile->Close();
  ofile->Delete();
  
  return 0;
}




